/*
    Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/** \file

    Contains the necessary definitions to use grid factories.
    \ingroup mcrx */

#ifndef __grid_factory__
#define __grid_factory__

#include <iostream>
#include "grid.h"
#include "refinement_accuracy_data.h"
#include "boost/thread/thread.hpp"
#include "boost/thread/mutex.hpp"
#include "hilbert.h"
#include "tbb/concurrent_hash_map.h"
#define TBB_PREVIEW_CONCURRENT_PRIORITY_QUEUE 1
#include "tbb/concurrent_priority_queue.h"
#include "mcrx-debug.h"

namespace mcrx {
  // Forward declarations of the helper types defined further down in
  // this file. The class-key (struct) matches the definitions below:
  // mismatched keys trigger -Wmismatched-tags / MSVC C4099 and change
  // the mangled name of the type on MSVC.
  struct cell_code_level_predicate;
  struct cell_code_hash_compare;
  template<typename, typename> class grid_creator;
}




/** Grid constructor that lets a grid factory build the grid.

    All refinement is done by a grid_creator, whose constructor drives
    the factory and fills in this grid. After it has run, one domain
    entry (hilbert::cell_code(1,0)) is recorded, with domainstart_
    holding the cell index range [0, n_cells()).

    \param lower   Grid minimum, stored in min_.
    \param upper   Grid maximum, stored in max_.
    \param factory Factory consulted by the grid_creator. */
template <typename cell_data_type>
template <typename factory_type>
mcrx::adaptive_grid<cell_data_type>::
adaptive_grid (const vec3d & lower, const vec3d & upper,
	       factory_type& factory) :
  min_(lower), max_(upper), c_(mpi_rank()), subgrid_block_ (0), data_block_(0)
{
  typedef grid_creator<adaptive_grid<cell_data_type>, factory_type> T_creator;

  // Merely constructing the creator builds the whole grid.
  T_creator creator (*this, factory);

  // Register a single domain spanning every cell.
  domain_.push_back (hilbert::cell_code (1, 0));
  domainstart_.push_back (0);
  domainstart_.push_back (n_cells ());
}


/** Comparator ordering cell codes by refinement level only.

    Used as the Compare argument of the cell priority queue in
    grid_creator. TBB's concurrent_priority_queue, like
    std::priority_queue, hands out the element that compares greatest
    first, so deeper (higher-level) cells are popped before coarser
    ones. */
struct mcrx::cell_code_level_predicate {
  bool operator()(const hilbert::cell_code a,
		  const hilbert::cell_code b) const
  {
    // a precedes b when a sits at a coarser level than b
    return b.level() > a.level();
  }
};

/** HashCompare policy for tbb::concurrent_hash_map keyed on cell codes.
    Keys are equal when the cell codes compare equal, and the hash is
    the raw code value converted to size_t. */
struct mcrx::cell_code_hash_compare {
  static bool equal(const hilbert::cell_code a,
		    const hilbert::cell_code b)
  {
    const bool same = (a == b);
    return same;
  }

  static size_t hash(const hilbert::cell_code key)
  {
    const size_t h = size_t(key.code());
    return h;
  }
};

/** Creates a grid using a grid factory to obtain the grid data.
    Grid_creator is a utility object that has the same function as the
    nbody_data_grid when it comes to doing refinement.  (In fact,
    makegrid could have been written to use this mechanism instead,
    but it predates the grid factory functionality.)

    All work is done in the constructor, which calls build_grid().
    Refinement may be spread over several threads: cell codes are
    handed between threads through cell_queue_, and the results of
    handed-off cells come back through racc_map_.  \ingroup mcrx */
template <typename grid_type, typename factory_type>
class mcrx::grid_creator {
public:
  typedef grid_type T_grid;
  typedef typename T_grid::T_data T_data;
  typedef typename T_grid::T_cell T_cell;
  typedef typename T_grid::T_cell_tracker T_cell_tracker;
  typedef factory_type T_factory;
  typedef typename T_cell_tracker::T_code T_code;
  typedef refinement_accuracy_data<T_data> T_racc;

private:
  typedef tbb::concurrent_hash_map<T_code, T_racc, 
				   cell_code_hash_compare> T_racc_map;

  /// \name Functions for grid building and refinement.
  ///@{
  void build_grid ();
  T_racc recursive_refine (T_cell_tracker& c);
  T_racc recursive_refine_body (T_cell_tracker& c);
  void unrefine_if_possible (const T_cell_tracker& c, 
			     T_racc& racc);
  ///@}

  T_factory& factory_;
  grid_type& g_;

  /** Done flag set by the thread that completes the top cell
      (T_code(0,0)) and polled by idle threads. It is written and read
      by different threads, so it must be atomic: volatile gives no
      inter-thread ordering, and unsynchronized concurrent access to a
      plain/volatile bool is a data race. */
  tbb::atomic<bool> done_; 

  /// Atomic counter for the number of idle threads.
  tbb::atomic<int> idle_;

  /// Queue containing codes of cells to be processed.
  tbb::concurrent_priority_queue<T_code, cell_code_level_predicate> cell_queue_;

  /** Hash map containing the refinement_accuracy_data objects of
      completed cell_queue requests. */
  T_racc_map racc_map_;

  /// \name Threading data structures.
  ///@{

  /// Function object used to start threads.
  class thread_start; 

  void pop_cell_and_refine ();
  ///@}


public:
  /// Builds the grid g using factory f. The whole build runs here.
  grid_creator (T_grid& g, T_factory& f);
};

template <typename grid_type, typename factory_type>
mcrx::grid_creator<grid_type, factory_type>::
grid_creator (grid_type& g, T_factory& f):
  factory_ (f), g_(g) { 
  // tbb::atomic members are initialized by assignment because
  // tbb::atomic may lack a value constructor (it does in pre-C++11
  // TBB builds).
  done_ = false;
  idle_ = 0;
  build_grid ();
}


/** Function object handed to boost::thread_group::create_thread.
    Invoking it runs pop_cell_and_refine() on the grid_creator it was
    constructed with.  \ingroup mcrx */
template <typename grid_type, typename factory_type>
class mcrx::grid_creator<grid_type, factory_type>::thread_start {
public:
  thread_start (grid_creator<grid_type, factory_type>* owner)
    : self (owner) {}

  /// Thread body: run the worker loop of the owning creator.
  void operator () () { self->pop_cell_and_refine (); }

private:
  /// Creator whose worker loop is run (not owned by this object).
  grid_creator<grid_type, factory_type>* self;
};

  
/** Accumulate the refinement_accuracy_data over the 8 child cells.
    (c points to a refined cell, and when this function returns it
    will point to the same cell.)

    Children 0-6 are handed to another thread when an idle one can be
    claimed with decrement_if_greater_than_zero(idle_): the child's
    code is pushed onto cell_queue_ and its result is later collected
    from racc_map_. Child 7, and any child for which no idle thread
    was available, is processed on the calling thread with
    recursive_refine_body(). After the loop the function busy-waits
    until every handed-off child's result has appeared in racc_map_,
    adds it into racc and erases it from the map.

    \param c Tracker pointing at a non-leaf cell (asserted).
    \return The accumulated refinement_accuracy_data of the 8 children. */
template <typename grid_type, typename factory_type>
mcrx::refinement_accuracy_data<typename grid_type::T_data>
mcrx::grid_creator<grid_type, factory_type>::recursive_refine (T_cell_tracker& c)
{
  assert(!c.is_leaf());
  
  // code of the parent; used to reconstruct child codes and to check
  // that the tracker returns here
  const T_code code = c.code();

  // descend to first child
  c.dfs_bidir();

  // process the children. we always do oct 7 ourselves since it
  // doesn't make sense to farm out all work on other threads and then
  // sit idle
  // processed[i] is false only for children handed to another thread
  // whose result has not yet been collected.
  blitz::TinyVector<bool,8> processed(true);
  T_racc racc;
  // the first child processed on this thread initializes racc by
  // assignment; later results are added with +=
  bool first=true;
  for(int i=0; i<8; ++i) {
    assert(c.code().level()==code.level()+1);
    // the decrement claims one idle thread for this child; the idle
    // thread itself does not decrement the counter after popping
    if(i<7 && decrement_if_greater_than_zero(idle_)) {
      // waiting threads, push this cell on the queue. Make sure we
      // increment the tracker BEFORE pushing, as otherwise it will
      // descend when another thread refines it.
      DEBUG(1,std::stringstream s; s<< "Idle threads, pushing code " << c.code() << '\n'; std::cout << s.str(););
      const T_code cc(c.code());
      c.dfs_bidir();
      cell_queue_.push(cc);
      processed[i]=false;
    }
    else {
      // no waiting threads (or oct 7), process it
      if(first) {
	racc = recursive_refine_body(c);
	first=false;
      }
      else 
	racc+= recursive_refine_body (c);
      c.dfs_bidir();
    }
  }

  
  // are we waiting on cells to be processed?
  while(!all(processed)) {
    // busy-wait for all results to appear.
    for(int i=0; i<8; ++i) {
      if(!processed[i]) {
	typename T_racc_map::const_accessor a;
	// rebuild child i's code from the parent code. NOTE(review):
	// assumes add_right(i) yields the same code the tracker gave the
	// child when it was pushed above -- confirm against hilbert.h.
	T_code newcode(code);
	newcode.add_right(i);
	if(racc_map_.find(a, newcode)) {
	  // found it. add it to racc and remove from map
	  DEBUG(1,std::stringstream s; s<< "Got racc result for code " << newcode << '\n'; std::cout << s.str(););

	  racc += a->second;
	  racc_map_.erase(a);
	  processed[i]=true;
	}
      }
    }
  }

  // check that we are back at parent
  assert(c.code()==code);

  return racc;
}


/** Lets the factory decide whether the children of the cell c points
    to should be merged back into a single leaf.

    factory_.unify_cell_p(c, racc) makes the decision. If it agrees,
    the cell is unrefined and handed a newly allocated T_data built by
    T_data::unification(racc.sum, racc.n). Ownership presumably passes
    to the cell (TODO confirm in unrefine()). If it does not agree,
    racc.all_leaves is cleared to record that the sub cells were not
    unified.

    \param c    Tracker pointing at a refined (non-leaf) cell.
    \param racc Accumulated data of the cell's children; may be modified. */
template <typename grid_type, typename factory_type>
void
mcrx::grid_creator<grid_type, factory_type>::
unrefine_if_possible (const T_cell_tracker& c,
		      T_racc& racc)
{
  assert (!c.is_leaf());

  if (!factory_.unify_cell_p (c, racc)) {
    // children are kept: mark that the sub cells were not unified
    racc.all_leaves = false;
    return;
  }

  // collapse the children into one leaf carrying the unified data
  T_data* const unified = new T_data (T_data::unification (racc.sum, racc.n));
  c.cell()->unrefine(unified);
}


/** Worker loop. It is run by each thread (via thread_start) and, in
    the single-threaded case, directly by build_grid().

    Repeatedly pops a cell code from cell_queue_, locates that cell
    with this thread's own tracker, processes it with
    recursive_refine_body() and inserts the result into racc_map_
    under the cell's code.

    When the queue is empty, the thread registers as idle by
    incrementing idle_ and spins on the queue. The matching decrement
    is done by the thread that pushes work, not by the idle thread.
    The loop ends when done_ is observed while the thread is idle. */
template <typename grid_type, typename factory_type>
void 
mcrx::grid_creator<grid_type, factory_type>::pop_cell_and_refine()
{
  std::cout << "thread starting\n";
  T_cell_tracker current_cell(g_);
  while (true) {

    // pop a code off the queue. if queue is empty, increase idle
    // counter and busy wait for some work (or for the done flag).
    T_code code;
    bool popped = cell_queue_.try_pop(code);
    if(!popped) {
      DEBUG(1,std::stringstream s; s<< "Thread idle\n"; std::cout << s.str(););
      ++idle_;
      while(!popped && !done_)
	popped = cell_queue_.try_pop(code);
      // we don't decrease the counter, because that's done by the
      // thread that pushes the code onto the queue. this avoids
      // unnecessary pushing.
      if(!popped) {
	// grid is finished. leave the loop (instead of returning from
	// inside it) so the exit message below is actually reached.
	break;
      }
    }
    assert(popped);
    DEBUG(1,std::stringstream s; s<< "Thread popped code " << code << '\n'; std::cout << s.str(););

    // have code. locate that cell and process it
    current_cell.restart();
    current_cell.locate(code);
    typename T_racc_map::value_type v(code, 
				      recursive_refine_body (current_cell));
    DEBUG(1,std::stringstream s; s<< "Thread pushed racc for code " << code << '\n'; std::cout << s.str(););
    const bool inserted=racc_map_.insert(v);
    assert(inserted);
    (void)inserted; // only checked by the assert; avoid NDEBUG warning

  } // while 

  std::cout << "thread exiting\n";
}


/** Checks if the current cell should be refined and either processes
    the refinement or calculates the data of the leaf cell.

    If factory_.refine_cell_p(c) is true, the cell is refined, its
    children are processed by recursive_refine() (possibly partly on
    other threads), and unrefine_if_possible() is called when all of
    them ended up as leaves. Otherwise the cell's data is obtained from
    factory_.get_data(c). When the processed cell has code T_code(0,0),
    done_ is set so that idle worker threads stop.

    \param c Tracker pointing at a leaf cell (asserted). It points to
    the same cell on return.
    \return The refinement_accuracy_data for the cell: accumulated
    over the children when refined, or (data, data*data, 1, true) for
    a leaf. */
template <typename grid_type, typename factory_type>
typename mcrx::refinement_accuracy_data<typename grid_type::T_data>
mcrx::grid_creator<grid_type, factory_type>::recursive_refine_body (T_cell_tracker& c)
{
  assert (c->is_leaf());
  // Code of the cell being processed. It was previously redeclared in
  // the refinement branch, which shadowed this one and left it unused.
  const T_code code = c.code();

  // Check if we should refine the cell
  if (factory_.refine_cell_p(c)) {
    // yes, refine and descend 
    c->refine();
    T_racc racc(recursive_refine (c));

    if (racc.all_leaves) {
      // the grid contains only leaf cells.  We can try to unrefine
      // them
      unrefine_if_possible (c, racc);
    }
    if(c.code()==T_code(0,0))
      done_=true;
    // recursive_refine must leave the tracker on this cell
    assert(code==c.code());
    return racc;
  }
  else {
    // no, don't refine.  Get the cell data from the factory 
    c.cell()->set_data(new T_data (factory_.get_data(c)));
    if(c.code()==T_code(0,0))
      done_=true;
    return T_racc (*c.cell()->data (),
		   *c.cell()->data()* *c.cell()->data(), 
		   1, true);
  }
}


/** Builds the grid.

    With factory_.n_threads() <= 1, the top cell's code is queued and
    the calling thread runs the worker loop itself. Otherwise a pool of
    n_threads workers is started. Once all of them have registered as
    idle, one is claimed (by decrementing idle_) and the top cell's
    code is pushed for it. The call then joins the pool. Finally the
    resulting cell count is printed. */
template <typename grid_type, typename factory_type>
void mcrx::grid_creator<grid_type, factory_type>::
build_grid ()
{
  const int n_threads = factory_.n_threads();

  T_cell_tracker top(g_);
  top.restart();

  if (n_threads <= 1) {
    // serial build: the caller does all the work
    cell_queue_.push(top.code());
    pop_cell_and_refine ();
  }
  else {
    boost::thread_group workers;
    for (int t = 0; t < n_threads; ++t)
      workers.create_thread(thread_start (this));

    // spin until every worker has found the queue empty and gone idle
    while (idle_ < n_threads) {
    }

    // the pusher claims an idle worker by decrementing the counter
    const bool claimed = decrement_if_greater_than_zero(idle_);
    if (claimed)
      cell_queue_.push(top.code());
    else
      std::cerr << "wtf??\n";

    workers.join_all();
  }

  std::cout << "Final grid contains " << g_.n_cells () << " cells" << std::endl;
}

#endif
